Stroke Risk Dataset¶

Viewing recommendations¶

For best experience open the pre-rendered .html file

You have 2 options depending on how much you trust my code and how little time you want to dedicate to this review :)

  1. ✅ open the prerendered HTML file, which contains all of the data without any of the rendering limitations that github suffers from
    • Best if you just want to review the analysis and conclusions You have two options for optimal viewing experience:
  2. ✅ open the notebook in jupyterlab (instead of github) and re-run all cells
    • Best if you want to verify that the code works and that the data is not fake

Why you should NOT visualize this notebook on github

Some charts/diagrams/features are not visible in github.

  • This impacts all:
    • plotly plots
    • folium maps
    • embedded images
    • and all other dynamic content
    • quarto meta-tags that improves visualization

This is standard and well-known behaviour.

  • While some workarounds could be possible, there are no universal fixes that resolve all the issues.
    • Plotly customization requires installing external tools, but doesn't fix the other issues
    • Standard workarounds (nbviewer) do not work because the github repo is private.

If you chose to run this locally, there are some prerequisites:

  • you will need python 3.9
  • you will need to install the dependencies using pip install -r requirements.txt before proceeding.

Context¶

The hospital where we work (The Johns Hopkins Hospital) has asked us to generate a simple self-assessment application that patients can use to understand their health risks.

In this first iteration, we will cover stroke risk, and if the experiment goes well, we can expand it to other cases in the future, after this project is finished.

In order to achieve, we need to build a system that can predict someone's risk of stroke given some info about them.

We have found a dataset online which we can use to train our model. For this exercise, let's assume that the dataset is of high quality and correctly identifies cases without errors, mistakes or issues. Since the data is sensitive, we're starting with a subset of the dataset to see if we can make predictions from a few datapoints.

As with everything health-related, we want to avoid false negatives, which would give the patient a false sense of security and might make them not seek help. We should also strive to not cause alarm unnecessarily (low false positives would also be good), but the primary metric to measure success should be a low false negative rate.

Exploratory Analysis¶

In [108]:
import ipywidgets as widgets
from IPython.display import display, Markdown, Image, clear_output, HTML
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.stats import chi2_contingency

import statsmodels.api as sm

from random import random, seed
import sqlite3 as lite
import logging
import warnings
import ydata_profiling
import iplantuml
import xml.dom.minidom

from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    recall_score,
    precision_score,
)
from sklearn.preprocessing import MinMaxScaler, StandardScaler, FunctionTransformer
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.tree import plot_tree
from sklearn.metrics import make_scorer, confusion_matrix, PrecisionRecallDisplay
from sklearn.pipeline import Pipeline
from sklearn.impute import KNNImputer
from sklearn.pipeline import make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin

import xgboost as xgb
import lightgbm as lgbm
from catboost import CatBoostClassifier
import shap

from imblearn.over_sampling import SMOTE
import joblib

from utils import *
from utils import __

# from analysis import *

from autofeat import AutoFeatRegressor as AFR
from sklearn.model_selection import train_test_split
from sklearn.exceptions import DataConversionWarning

import category_encoders as ce

seed(100)
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 50
util.check("done")
✅

Let's use black to auto-format all our cells so they adhere to PEP8

In [2]:
import lab_black

%reload_ext lab_black
util.patch_nb_black()
# fmt: off
# fmt: on
In [3]:
from sklearn import set_config

set_config(transform_output="pandas")
In [4]:
logger = util.configure_logging(jupyterlab_level=logging.WARN, file_level=logging.DEBUG)

warnings.filterwarnings("ignore", category=FutureWarning)

Fetching and loading the dataset¶

We're ready to start! Let's download the dataset from Kaggle.

In [5]:
dataset_name = "fedesoriano/stroke-prediction-dataset"
db_filename = "healthcare-dataset-stroke-data.csv"

auto_kaggle.download_dataset(dataset_name, db_filename)
__
Kaggle API 1.5.13 - login as 'edualmas'
File [dataset/healthcare-dataset-stroke-data.csv] already exists locally!
No need to re-download dataset [fedesoriano/stroke-prediction-dataset]
In [6]:
raw_ = pd.read_csv(f"dataset/{db_filename}", index_col=0)
In [7]:
raw_.columns
Out[7]:
Index(['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
       'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',
       'smoking_status', 'stroke'],
      dtype='object')
In [8]:
raw_.columns = raw_.columns.str.lower()
In [9]:
sns.pairplot(raw_)
Out[9]:
<seaborn.axisgrid.PairGrid at 0x7fadf0743520>
In [10]:
raw_.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 5110 entries, 9046 to 44679
Data columns (total 11 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   gender             5110 non-null   object 
 1   age                5110 non-null   float64
 2   hypertension       5110 non-null   int64  
 3   heart_disease      5110 non-null   int64  
 4   ever_married       5110 non-null   object 
 5   work_type          5110 non-null   object 
 6   residence_type     5110 non-null   object 
 7   avg_glucose_level  5110 non-null   float64
 8   bmi                4909 non-null   float64
 9   smoking_status     5110 non-null   object 
 10  stroke             5110 non-null   int64  
dtypes: float64(3), int64(3), object(5)
memory usage: 479.1+ KB
In [11]:
raw_.describe().T
Out[11]:
count mean std min 25% 50% 75% max
age 5110.0 43.226614 22.612647 0.08 25.000 45.000 61.00 82.00
hypertension 5110.0 0.097456 0.296607 0.00 0.000 0.000 0.00 1.00
heart_disease 5110.0 0.054012 0.226063 0.00 0.000 0.000 0.00 1.00
avg_glucose_level 5110.0 106.147677 45.283560 55.12 77.245 91.885 114.09 271.74
bmi 4909.0 28.893237 7.854067 10.30 23.500 28.100 33.10 97.60
stroke 5110.0 0.048728 0.215320 0.00 0.000 0.000 0.00 1.00

ON MISSING DATA It seems that the dataset is almost full, except for some minor gaps around BMI.

Let's drop those few rows that are missing BMI so we have a dataset without NAs (4% of the data).

ON BMI AS A METRIC

According to literature, BMI should ideally be around 25. It is beyond the scope of this initial exploration to judge this particular score, but people who consider it useful MIGHT also consider that on average, this dataset might skew towards the heavier side/overweight bodies.

We hope that the models we build in this report will help us understand whether or not this score actually contributes in a statisically significant way to the prediction of strokes.

In [12]:
sns.histplot(raw_.gender)
Out[12]:
<AxesSubplot: xlabel='gender', ylabel='Count'>

💛🤍💜🖤 Gender is not binary, and this dataset earns a token inclusivity badge 🥇 for this single datapoint.

HOWEVER, The singular "Other" entry really gives us nothing to work with.

If there had been more data, we could have explored this a bit further, but as things are, we are pretty much forced to exclude this lone data point and to treat the "gender" column as if it were binary.

We will encode it as such:

  • 0 = Female
  • 1 = Male

With this encoding, it will treat "Female" as the default, and whatever coefficient the models assign to this feature will represent the additional risk if the person is male (=1). A positive value will indicate higher risk for male population, while a negative coefficient will indicate a lower risk (comparatively).

In [13]:
sns.histplot(raw_.ever_married)
Out[13]:
<AxesSubplot: xlabel='ever_married', ylabel='Count'>
In [14]:
raw_.smoking_status.value_counts()
Out[14]:
never smoked       1892
Unknown            1544
formerly smoked     885
smokes              789
Name: smoking_status, dtype: int64
In [15]:
sns.histplot(raw_.work_type)
Out[15]:
<AxesSubplot: xlabel='work_type', ylabel='Count'>

We will also drop "Never Worked" rows (0.4%), since there is not enough data to actually use it for predictions (without massive overfitting the few datapoints it will contain). We will also split this dataset into "private sector" and "not private sector" to try to balance it a bit.

In [16]:
sns.histplot(raw_.residence_type)
Out[16]:
<AxesSubplot: xlabel='residence_type', ylabel='Count'>

Everything else looks good and usable.

Let's do a bit of cleaning and remove rare values

In [17]:
def drop_rare_values(df: pd.DataFrame):
    df = df.drop(df[df["gender"].isin(["Other"])].index)
    df = df.drop(df[df["work_type"].isin(["Never_worked"])].index)
    return df


raw_ = drop_rare_values(raw_)
In [18]:
split = pipeline_utils.train_val_test_split(raw_, target_col="stroke", rnd=30)
split
Out[18]:
<TrainValTestData instance> 
Train = X:(3051, 10), y:(3051,)
Val = X:(1018, 10), y:(1018,)
Test = X:(1018, 10), y:(1018,)

Let's check that stratification worked

In [19]:
print(sum(split._y_train) / len(split._y_train))
print(sum(split._y_val) / len(split._y_val))
print(sum(split._y_test) / len(split._y_test))
0.04883644706653556
0.04911591355599214
0.04911591355599214

Excellent! all 3 datasets have the same % of stroke perdictions (1 vs 0s)

In [20]:
train_data = split.train_df()
In [21]:
train_data.head()
Out[21]:
gender age hypertension heart_disease ever_married work_type residence_type avg_glucose_level bmi smoking_status stroke
id
58761 Male 52.0 0 0 Yes Private Urban 87.51 30.5 formerly smoked 0
48368 Female 65.0 0 0 Yes Self-employed Rural 104.21 36.8 never smoked 0
8168 Female 34.0 0 0 Yes Private Rural 112.54 23.4 formerly smoked 0
44426 Female 21.0 0 0 Yes Private Urban 126.35 26.9 never smoked 0
24783 Female 28.0 0 0 No Private Urban 87.91 22.7 formerly smoked 0
In [22]:
ydata_utils.report(train_data)
Summarize dataset:   0%|          | 0/5 [00:00<?, ?it/s]
Generate report structure:   0%|          | 0/1 [00:00<?, ?it/s]
Render HTML:   0%|          | 0/1 [00:00<?, ?it/s]
Out[22]:

We're particularly interested in the correlation matrix. It seems that the strongest correlation with "stroke" is "age".

Let's review the other insights from this automated report.

Interesting insights uncovered:

Some columns are highly imbalanced

  • hypertension: +50% imbalance
  • heart_disease: ~70% imbalance
  • stroke: ~70% imbalance

These scores are calculated using 1 - (entropy(value_counts, base=2) / log2(n_classes)) source, but for human readable values, let's check one of the examples:

  • stroke is scored with a 71.9% but:
    • 5% of the rows have "1"
    • 95% of the rows have "0"

These raw numbers paint a much grimmer picture in term of the imbalance.

We will need to be very careful when trying to create models for this, as a dummy model that always responded with "no stroke" would get a 95% accuracy score without doing any of the work.

We will aim for better RECALL, because those mistakes would be more costly.

In [23]:
train_data
Out[23]:
gender age hypertension heart_disease ever_married work_type residence_type avg_glucose_level bmi smoking_status stroke
id
58761 Male 52.0 0 0 Yes Private Urban 87.51 30.5 formerly smoked 0
48368 Female 65.0 0 0 Yes Self-employed Rural 104.21 36.8 never smoked 0
8168 Female 34.0 0 0 Yes Private Rural 112.54 23.4 formerly smoked 0
44426 Female 21.0 0 0 Yes Private Urban 126.35 26.9 never smoked 0
24783 Female 28.0 0 0 No Private Urban 87.91 22.7 formerly smoked 0
... ... ... ... ... ... ... ... ... ... ... ...
65962 Male 50.0 0 0 Yes Private Urban 58.70 38.9 smokes 0
19584 Female 20.0 0 0 No Private Urban 84.62 19.7 smokes 0
66051 Male 43.0 0 0 Yes Self-employed Rural 115.79 31.8 Unknown 0
58477 Female 45.0 0 0 Yes Private Urban 81.24 37.0 never smoked 0
64849 Female 42.0 0 0 Yes Private Urban 92.20 34.2 Unknown 0

3051 rows × 11 columns

In [24]:
age_group = train_data.groupby("age")
age = age_group.count().index.astype("float")

age_stroke = age_group.stroke.agg("sum")
age_total = age_group.stroke.agg("count")
percent = 100 * (age_stroke / age_total).astype("float")

sns.lineplot(
    x=age,
    y=percent,
)
plt.axvline(x=37, color="orange", linestyle=":", alpha=0.3)
plt.title("stroke risk over age (% of cases)")
plt.ylabel("% of cases")
Out[24]:
Text(0, 0.5, '% of cases')

It seems that there are very rare cases of stroke before ~40 year olds.

Since this dataset is highly unbalanced, we could have a model that accounts for this and specialises in people who are +40 years old.

If we decide to make a model that only provides predictions for +40s, we should have a UI that still accounts for this and lets the user know about the risks of stroke instead of just saying "you are not at risk", always.

This is beyond the scope of this modelling exercise (notebook).

This simple example shows the importance of having interdisciplinary teams that can own a product end-to-end instead of having knowledge silos (one team for ML/AI/DS, and another team for frontends...), this tiny details could fall through the cracks and end up having terrible consequences (false "not at risk" expectations for end users)

In [25]:
f, ax = plt.subplots(2, 5, figsize=(18, 8))
cols = set(train_data.columns) - {"stroke"}
for i, c in enumerate(cols):
    axx = ax[i // 5, i % 5]
    sns.histplot(train_data, x=c, hue="stroke", multiple="stack", ax=axx)
    chart_utils.rotate_x_labels(axx)

plt.tight_layout()

2 interesting things to highlight from these charts:

  • AGE: The only obvious predictor seem to be "Age". All the other ones seem to follow similar proportions.
  • EVER_MARRIED: The strong presence of "Stroke" in "Ever Married = Yes". But this could be caused by lots of factors: Survivor's bias being one of the more obvious ones. Whatever it is, we should NOT take conclusions on causality from such a simple dataset.
In [26]:
train_data.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3051 entries, 58761 to 64849
Data columns (total 11 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   gender             3051 non-null   object 
 1   age                3051 non-null   float64
 2   hypertension       3051 non-null   int64  
 3   heart_disease      3051 non-null   int64  
 4   ever_married       3051 non-null   object 
 5   work_type          3051 non-null   object 
 6   residence_type     3051 non-null   object 
 7   avg_glucose_level  3051 non-null   float64
 8   bmi                2929 non-null   float64
 9   smoking_status     3051 non-null   object 
 10  stroke             3051 non-null   int64  
dtypes: float64(3), int64(3), object(5)
memory usage: 350.6+ KB

We see that one of the columns is missing some data (BMI)

Let's make a list of all the cleaning tasks we want to apply to our data:

The preparation tasks will depend on our algorithm chosen to create our final model.

  • XGBoost/LightGBM

    • [x] Encoding Categorical Features
    • [x] Fill in NAs in BMI using their 3 Nearest Neighbors.
    • [x] Standardization
  • CatBoost

    • [x] Identify categorical columns

We will require ML-Algorithm-agnostic transformation steps that can be applied in whatever pipelines are needed. These transformations should be general enough to allow reuse (whenever possible) across situations, and not be strongly-coupled to a specific algorithm (unless unavoidable).

Statistical Inferencial Analysis¶

Age vs Stroke¶

It seems that age is one of the principal drivers that can help us predict stroke.

Let's try to check this hypothesis, from the data that we have.

We can compare the average stroke % for people above/below 40 years old, and see if the difference is statistically significant enough to reject the null hypothesis.

  • Null hypothesis: There is not significant difference between people above 40 years old, and people below.
  • Alternative hypothesis: There is a significant difference.
  • Population: All people (considering the data is from the WHO, we presume the data is a subset that is representative of a broad subset of the population). This is speculation, the dataset does not include any specifics.
  • Significance: 0.05
In [27]:
data = split.train_df()

case_a = data[data.age <= 40].stroke
case_b = data[data.age > 40].stroke
contingency = np.array([case_a.value_counts().values, case_b.value_counts().values])
contingency
Out[27]:
array([[1342,    4],
       [1560,  145]])
In [28]:
chi2, p, dof, expected = chi2_contingency(contingency)
p
Out[28]:
3.7992380192189187e-25

Since the p-value is much smaller than our significance, we can reject the null hypothesis.

Note of importance. This analysis makes no claim in terms of causality.

Since the dataset does not specify how this data was collected, we cannot rule out any possibilities around causality. Maybe Age increases the risk of stroke, maybe only survivors of stroke make it to older ages (and only those made it to the study that led to this dataset), maybe there are confounding variables that link those two, etc...

BMI vs Stroke¶

BMI is often claimed to be a risk factor to numerous health complications.

Let's compare the average stroke % for people above/below a BMI score of 25 (considered the threshold for overweight in some studies), and see if the difference is statistically significant enough to reject the null hypothesis.

  • Null hypothesis: There is not significant difference between people above/below BMI of 25 Kg/m2.
  • Alternative hypothesis: There is a significant difference.
  • Population: All people (considering the data is from the WHO, we presume the data is a subset that is representative of a broad subset of the population). This is speculation, the dataset does not include any specifics.
  • Significance: 0.05
In [29]:
data = split.train_df()

case_a = data[data.bmi <= 25].stroke
case_b = data[data.age > 25].stroke
contingency = np.array([case_a.value_counts().values, case_b.value_counts().values])
contingency
Out[29]:
array([[ 944,   22],
       [2146,  147]])
In [30]:
chi2, p, dof, expected = chi2_contingency(contingency)
p
Out[30]:
1.8125172435548476e-06

Since the p-value is much smaller than our significance, we can reject the null hypothesis.

Same as above, this simple test does not give us enough answers to confidently make claims around causality.

Marriage Status vs Stroke¶

The old addage of "married people live longer" will be put to the test, today, using one of the least comprehensive datasets in existance.

  • Null hypothesis: There is not significant difference between people who married once or more times in their lifetime (there is no requirement around "minimum duration of the marriage" in order to qualify as "ever married").
  • Alternative hypothesis: There is a significant difference.
  • Population: All people (considering the data is from the WHO, we presume the data is a subset that is representative of a broad subset of the population). This is speculation, the dataset does not include any specifics.
  • Significance: 0.05
In [31]:
data = split.train_df()

case_a = data[data["ever_married"] == "Yes"].stroke
case_b = data[data["ever_married"] != "Yes"].stroke
contingency = np.array([case_a.value_counts().values, case_b.value_counts().values])
contingency
Out[31]:
array([[1873,  130],
       [1029,   19]])
In [32]:
chi2, p, dof, expected = chi2_contingency(contingency)
p
Out[32]:
2.09529661607514e-08

Since the p-value is much smaller than our significance, we can reject the null hypothesis.

Same as above, this simple test does not give us enough answers to confidently make claims around causality.

Preparing data for XGBoost/LightGBM¶

In [33]:
from sklearn.preprocessing import FunctionTransformer
In [34]:
def remove_whitespace_values(df: pd.DataFrame):
    for col in df.select_dtypes(include=["category"]).columns:
        df[col].cat.rename_categories(lambda x: x.replace(" ", "_"), inplace=True)

    for col in df.select_dtypes(include=["object"]).columns:
        df[col].replace(" ", "_", regex=True, inplace=True)
    return df


def columns_to_lowercase(df: pd.DataFrame):
    df.columns = df.columns.str.lower()
    return df


def convert_types(df):
    df_out = df.copy()
    for col in df_out.columns:
        if pd.api.types.is_integer_dtype(df_out[col].dtype):
            df_out[col] = df_out[col].astype("int64")
        elif pd.api.types.is_float_dtype(df_out[col].dtype):
            df_out[col] = df_out[col].astype("float64")
        elif pd.api.types.is_bool_dtype(df_out[col].dtype):
            df_out[col] = df_out[col].astype("bool")
    return df_out


class ColumnDropper(BaseEstimator, TransformerMixin):
    def __init__(self, cols: list[str]):
        self.cols = cols

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X.drop(self.cols, axis=1)
In [35]:
def gb_pipeline():
    input_normalize_encode = ColumnTransformer(
        transformers=[
            ("scaler", StandardScaler(), ["bmi", "avg_glucose_level", "age"]),
            (
                "categorical",
                make_pipeline(
                    FunctionTransformer(remove_whitespace_values),
                    ce.OneHotEncoder(
                        return_df=True,
                        use_cat_names=True,
                    ),
                ),
                [
                    "gender",
                    "ever_married",
                    "work_type",
                    "heart_disease",
                    "hypertension",
                    "residence_type",
                    "smoking_status",
                ],
            ),
        ],
        remainder="passthrough",
        verbose_feature_names_out=False,
    )

    return Pipeline(
        [
            ("clean", input_normalize_encode),
            ("fill missing", KNNImputer(n_neighbors=3, missing_values=np.nan)),
            (
                "drop_cols",
                ColumnDropper(
                    [
                        "gender_Female",
                        "ever_married_No",
                        "work_type_children",
                        "residence_type_Rural",
                        "smoking_status_never_smoked",
                    ]
                ),
            ),
            ("columns_to_lowercase", FunctionTransformer(columns_to_lowercase)),
            ("convert_types", FunctionTransformer(convert_types)),
        ]
    )

The category_encoders package contains sane encoders that do reasonable things without requiring boilerplate for common tasks. In this case, we strongly value that it keeps the feature names human readable instead of the scikit's default behaviour.

This packaged is used in the ce.OneHotEncoder above. The small prefix makes it hard to spot.

This library was a heaven-sent tool that made this so much easier to debug and inspect.

Preparing all our data¶

In [36]:
gb_clean_pipeline = gb_pipeline()
gb_clean_pipeline
Out[36]:
Pipeline(steps=[('clean',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('scaler', StandardScaler(),
                                                  ['bmi', 'avg_glucose_level',
                                                   'age']),
                                                 ('categorical',
                                                  Pipeline(steps=[('functiontransformer',
                                                                   FunctionTransformer(func=<function remove_whitespace_values at 0x7fad2660c820>)),
                                                                  ('onehotencoder',
                                                                   OneHotEncoder(use_cat_names=True))]),
                                                  ['gender', 'ever_marri...
                ('fill missing', KNNImputer(n_neighbors=3)),
                ('drop_cols',
                 ColumnDropper(cols=['gender_Female', 'ever_married_No',
                                     'work_type_children',
                                     'residence_type_Rural',
                                     'smoking_status_never_smoked'])),
                ('columns_to_lowercase',
                 FunctionTransformer(func=<function columns_to_lowercase at 0x7fad2660caf0>)),
                ('convert_types',
                 FunctionTransformer(func=<function convert_types at 0x7fad2660c790>))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('clean',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('scaler', StandardScaler(),
                                                  ['bmi', 'avg_glucose_level',
                                                   'age']),
                                                 ('categorical',
                                                  Pipeline(steps=[('functiontransformer',
                                                                   FunctionTransformer(func=<function remove_whitespace_values at 0x7fad2660c820>)),
                                                                  ('onehotencoder',
                                                                   OneHotEncoder(use_cat_names=True))]),
                                                  ['gender', 'ever_marri...
                ('fill missing', KNNImputer(n_neighbors=3)),
                ('drop_cols',
                 ColumnDropper(cols=['gender_Female', 'ever_married_No',
                                     'work_type_children',
                                     'residence_type_Rural',
                                     'smoking_status_never_smoked'])),
                ('columns_to_lowercase',
                 FunctionTransformer(func=<function columns_to_lowercase at 0x7fad2660caf0>)),
                ('convert_types',
                 FunctionTransformer(func=<function convert_types at 0x7fad2660c790>))])
ColumnTransformer(remainder='passthrough',
                  transformers=[('scaler', StandardScaler(),
                                 ['bmi', 'avg_glucose_level', 'age']),
                                ('categorical',
                                 Pipeline(steps=[('functiontransformer',
                                                  FunctionTransformer(func=<function remove_whitespace_values at 0x7fad2660c820>)),
                                                 ('onehotencoder',
                                                  OneHotEncoder(use_cat_names=True))]),
                                 ['gender', 'ever_married', 'work_type',
                                  'heart_disease', 'hypertension',
                                  'residence_type', 'smoking_status'])],
                  verbose_feature_names_out=False)
['bmi', 'avg_glucose_level', 'age']
StandardScaler()
['gender', 'ever_married', 'work_type', 'heart_disease', 'hypertension', 'residence_type', 'smoking_status']
FunctionTransformer(func=<function remove_whitespace_values at 0x7fad2660c820>)
OneHotEncoder(use_cat_names=True)
passthrough
KNNImputer(n_neighbors=3)
ColumnDropper(cols=['gender_Female', 'ever_married_No', 'work_type_children',
                    'residence_type_Rural', 'smoking_status_never_smoked'])
FunctionTransformer(func=<function columns_to_lowercase at 0x7fad2660caf0>)
FunctionTransformer(func=<function convert_types at 0x7fad2660c790>)
In [37]:
missingbmi = split._X_train[split._X_train.bmi.isna()].index
In [38]:
gb_train_clean = gb_clean_pipeline.fit_transform(split._X_train)
gb_train_clean.head()
Out[38]:
bmi avg_glucose_level age gender_male ever_married_yes work_type_private work_type_self-employed work_type_govt_job heart_disease hypertension residence_type_urban smoking_status_formerly_smoked smoking_status_unknown smoking_status_smokes
id
58761 0.236179 -0.408857 0.387306 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0
48368 1.067775 -0.039481 0.962764 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
8168 -0.701016 0.144765 -0.409483 0.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
44426 -0.239018 0.450220 -0.984942 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
24783 -0.793416 -0.400010 -0.675079 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0
In [39]:
gb_train_clean.loc[missingbmi]
Out[39]:
bmi avg_glucose_level age gender_male ever_married_yes work_type_private work_type_self-employed work_type_govt_job heart_disease hypertension residence_type_urban smoking_status_formerly_smoked smoking_status_unknown smoking_status_smokes
id
5984 -0.375418 -0.612788 -0.807878 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
7859 0.200979 -0.149630 -0.409483 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0
39202 1.146975 2.910444 0.785700 0.0 1.0 1.0 0.0 0.0 1.0 1.0 1.0 1.0 0.0 0.0
2549 0.799377 -0.503524 -1.162006 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
967 0.077780 -0.392047 0.785700 1.0 1.0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1678 0.231779 -0.160468 0.475838 0.0 1.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
20112 -0.375418 2.375180 1.582489 1.0 1.0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0
69768 -1.259814 -0.787966 -1.856098 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0
8968 0.319779 2.257510 -0.055355 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0
49894 0.940176 2.223669 1.538223 0.0 1.0 1.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0

122 rows × 14 columns

Looking at the split of missing BMIs over stroke/no-stroke will help us make better informed decisions later (see summary)

Check distribution of NAs in BMI¶

In [40]:
split._y_train.loc[gb_train_clean.loc[missingbmi].index].value_counts()
Out[40]:
0    101
1     21
Name: stroke, dtype: int64
In [41]:
gb_train_clean.bmi.drop(missingbmi)
Out[41]:
id
58761    0.236179
48368    1.067775
8168    -0.701016
44426   -0.239018
24783   -0.793416
           ...   
65962    1.344974
19584   -1.189414
66051    0.407778
58477    1.094175
64849    0.724577
Name: bmi, Length: 2929, dtype: float64
In [42]:
@run
def plot_bmis():
    originalbmi_df = gb_train_clean.bmi.drop(missingbmi)
    missingbmi_df = gb_train_clean.bmi.loc[missingbmi]
    plt.figure(figsize=(6, 4))
    j = sns.jointplot(space=0, ratio=2)

    j.ax_marg_x.set_visible(False)

    sns.scatterplot(originalbmi_df, color="lightgrey", ax=j.ax_joint)
    sns.scatterplot(missingbmi_df, color="orange", ax=j.ax_joint)
    sns.histplot(y=originalbmi_df, color="lightgrey", ax=j.ax_marg_y)
    sns.histplot(y=missingbmi_df, color="orange", ax=j.ax_marg_y)

    plt.title("Visualizing imputed BMI for missing values (KNN=3)")
    plt.legend(["original", "imputed"])
    plt.tight_layout()
<Figure size 600x400 with 0 Axes>

We can see that the distribution of imputed values seems to follow the general distribution of the overall population Since the only transformation we applied to this was scaling, the result should be the same, whether we input before or after scaling.

Let's check that the data still makes sense after all the transformations/cleanup

In [43]:
@run
def training_data_after_transformation():
    data = gb_train_clean.join(split._y_train, how="inner")
    f, ax = plt.subplots(2, 7, figsize=(24, 6))
    for i, c in enumerate(set(data.columns) - {"stroke"}):
        axx = ax[i // 7, i % 7]
        sns.histplot(data, x=c, hue="stroke", multiple="stack", ax=axx)
        chart_utils.rotate_x_labels(axx, rotation=80)

    plt.tight_layout()

Everything looks good:

  • [x] values are rescaled
  • [x] all categorical values are encoded properly
  • [x] all infrequent values (gender:none, neverworked, etc..) are no longer included
  • etc..

With this, we're ready to start creating models using XGBoost and LightGBM

Predicting Stroke Risk using boosted models¶

XGBoost¶

Let's try XGBoost first, as it seems to be generally regarded as a good starting point, with marginally better performance, even though it is sometimes slower.

In [44]:
def conf_matrix_stroke(expected, predicted, normalize="true", ax=None):
    return ConfusionMatrixDisplay.from_predictions(
        expected,
        predicted,
        labels=[0, 1],
        display_labels=["no stroke", "stroke"],
        normalize=normalize,
        ax=ax,
    )
In [ ]:
 
In [45]:
# for imbalanced data, see xgboost docs:
# https://xgboost.readthedocs.io/en/stable/tutorials/param_tuning.html#handle-imbalanced-dataset
eval_metric = "error"
scale_pos_weight = 3  # to be optimized later

xgbclassifier = xgb.XGBClassifier(
    objective="binary:logistic",
    seed=42,
    scale_pos_weight=scale_pos_weight,
    eval_metric=eval_metric,
    subsample=0.9,
    colsample_bytree=0.5,
    scoring="roc_auc",
    verbosity=1,
)

xgbclassifier.fit(X=gb_train_clean, y=split._y_train, verbose=10)
predicted = xgbclassifier.predict(gb_train_clean)

print("Accuracy:", accuracy_score(split._y_train, predicted))
print("Recall:", recall_score(split._y_train, predicted))
conf_matrix_stroke(split._y_train, predicted)
Accuracy: 0.9593575876761717
Recall: 0.35570469798657717
Out[45]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad26592580>

We are normalizing the confusion matrix over the true conditions (instead of the "prediction" condition), because we want to see the proportion of correctly classified instances for each class. This is particularly useful for classifications with imbalanced datasets: normalizing over the true conditions can help us surface problems with classification.

This is because false negatives are particularly problematic and dangerous for the situation at hand (not detecting at-risk patients)

Understanding the loss function¶

Let's plot the performance of the function so we can compare how it changes, according to the learning_rate

In [46]:
clean_eval = gb_clean_pipeline.transform(split._X_val)


def xgb_with_learning_rate(
    lr: float,
    early_stop: int = None,
    predict: bool = False,
    classifier=None,
) -> list[float]:
    """
    uses the provide classifier, or creates a new one with specific defaults
    if predict is set, it returns the predictions (for further processing),
    if predict is not set, the evaluation results are preprocessed and returned.
    """
    if not classifier:
        classifier = xgb.XGBClassifier(
            learning_rate=lr,
            n_estimators=1000,
            objective="binary:logistic",
            seed=42,
            scale_pos_weight=scale_pos_weight,
            subsample=0.9,
            colsample_bytree=0.5,
            scoring="roc_auc",
        )
    classifier.fit(
        gb_train_clean,
        split._y_train,
        eval_metric=["logloss", "aucpr"],
        eval_set=[(clean_eval, split._y_val)],
        verbose=0,
        early_stopping_rounds=early_stop,
    )
    if predict:
        prediction = classifier.predict(clean_eval)
        return prediction, split._y_val, classifier
    else:
        return classifier.evals_result()

Assessing performance using Logloss¶

ChatGPT says:

AUCPR is a measure of how well a classifier separates the positive and negative classes, taking into account the imbalance between the classes. It is calculated by plotting the precision (true positives / (true positives + false positives)) against the recall (true positives / (true positives + false negatives)) at different classification thresholds.

Logloss, measures the uncertainty of the probabilities of the predicted class labels, compared to the true class labels. It takes into account the confidence of the predictions, so a classifier that assigns a high probability to the correct class will have a lower logloss than a classifier that assigns a low probability to the correct class. A lower logloss value indicates better performance.

In summary, when reading the charts, we need to remember 2 things:

  • for AUCPR, higher is better, and Logloss, lower is better.
  • Logloss penalizes our model for "low confidence" predictions much more harshly than for "high confidence" predictions.
In [47]:
def plot_performance(metric_name: str, classifier=None, lrs=[0.005, 0.01, 0.1, 0.5, 1]):
    step = 10
    x = [
        xgb_with_learning_rate(lr, classifier=classifier)["validation_0"][metric_name][
            ::step
        ]
        for lr in lrs
    ]
    df = pd.DataFrame(x, index=lrs)
    sns.lineplot(df.T)
    plt.title(f"{metric_name} over iterations, by learning rate")
    plt.xlabel(f"iterations (x{step})")
    plt.ylim(0)
    plt.ylabel(metric_name)
In [48]:
plot_performance("logloss")
In [49]:
plot_performance("aucpr")

logloss¶

The logloss chart seems to show the same pattern across all cases:

  • starting poorly, approaching 0 at a specific speed (determined by the learning rate), and they starting to increase again due to overfitting.
    • the larger the learning rate, the sooner this spot of optimal performance is reached.
    • early stopping can be used to avoid this, but it was not utilised in this analysis to show how models perform without it.

aucpr / auprc¶

The AUCPR (Area Under the Curve / Precision vs Recall) shows a different picture.

  • ¨The area under the precision-recall curve (AUPRC) is a useful performance metric for imbalanced data in a problem setting where you care a lot about finding the positive examples¨ reference
  • This definitely applies to our dataset: because it is highly imbalanced, and because having false negatives has severe consequences.
  • This chart seems to show a different picture than the one above. They all pretty much stay around the 0.1 - 0.3 range, no matter the number of iterations

PR AUC seems to be the better metric to use for highly imbalanced datasets source

Plotting precision and recall (based on classification threshold)¶

Let's visualize how our model is performing against the validation dataset

In [50]:
predictions, expected, model = xgb_with_learning_rate(0.5, early_stop=10, predict=True)
predicted_proba = model.predict_proba(clean_eval)

charts.plot_precision_recall_over_threshold(split._y_val, predicted_proba)
In [51]:
PrecisionRecallDisplay.from_predictions(expected, predicted_proba[:, 1])
plt.legend(loc="upper right")
plt.ylim(0)
Out[51]:
(0.0, 1.0475442043222003)

The performance of this "out of the box" model is far from ideal:

  • Terrible PR curve: Area under the curve is very small.
  • Terrible confusion matrix: 65% of the stroke cases are not detected! We're indirectly killing more patients through inaction than we are saving! Almost 2x as many, in fact!

Our goal over the rest of the exercise is to improve these metrics by tuning better models.

Hyperparameter tuning XGBoost¶

We want to find ideal XGBoost hyperparameters that help our model learn our data better, to improve our predictive scores.

These seem to be common ranges for XGB hyperparameters: Source

  • max_depth: [3, 4, 5, 6, 7, 8, 9, 10]
  • min_child_weight: [1, 3, 5, 7]
  • gamma: [0.0, 0.1, 0.2, 0.3, 0.4]
  • subsample: [0.6, 0.7, 0.8, 0.9]
  • colsample_bytree: [0.6, 0.7, 0.8]
  • reg_alpha: [1e-5, 1e-2, 0.1, 1]
  • learning_rate: [0.01, 0.05, 0.1]

We will use a subset and only expand it if needed

In [52]:
def get_hyperparams_for(X, y):
    hyperparams_space = {
        "max_depth": [1, 2, 3, 4],
        "gamma": [0, 1, 3],
        "learning_rate": [0.001, 0.01, 0.1],
        "n_estimators": [1000],
        "scale_pos_weight": [1, 10, 20],
    }

    classifier = xgb.XGBClassifier(
        objective="binary:logistic",
        seed=42,
    )

    cv = GridSearchCV(
        estimator=classifier,
        param_grid=hyperparams_space,
        cv=3,
        scoring="roc_auc",
        error_score="raise",
        n_jobs=-1,
        verbose=1,
    )
    print(X.shape)
    print(y.shape)

    cv.fit(
        X=X,
        y=y,
        verbose=False,
    )
    return cv
In [53]:
@cached_with_pickle()
def get_best_hyperparameters():
    return get_hyperparams_for(gb_train_clean, split._y_train)


cv = get_best_hyperparameters()
Loading from cache [./cached/pickle/get_best_hyperparameters.pickle]
In [54]:
cv.best_params_
Out[54]:
{'colsample_bytree': 0.8,
 'gamma': 0,
 'learning_rate': 0.001,
 'max_depth': 3,
 'n_estimators': 100,
 'reg_alpha': 0.1,
 'scale_pos_weight': 5,
 'subsample': 0.7}

Building the model using the optimized hyperparameters:

In [55]:
params = cv.best_params_
classifier = xgb.XGBClassifier(
    objective="binary:logistic",
    seed=42,
    n_jobs=-1,
    scoring="roc_auc",
    eval_metric="error",
    **params
)

Let's compare the "optimized" vs "out of the box" models, to make sure they are properly configured

In [56]:
compare = pd.DataFrame(
    [classifier.get_params(), xgbclassifier.get_params()],
    index=["custom", "default"],
).T
compare["equals"] = compare["custom"] == compare["default"]
compare.sort_values(by="equals")
Out[56]:
custom default equals
n_jobs -1 1 False
subsample 0.7 0.9 False
silent None None False
colsample_bytree 0.8 0.5 False
learning_rate 0.001 0.1 False
scale_pos_weight 5 3 False
reg_alpha 0.1 0.0 False
missing None None False
nthread None None False
verbosity 1 1 True
seed 42 42 True
reg_lambda 1 1 True
random_state 0 0 True
objective binary:logistic binary:logistic True
base_score 0.5 0.5 True
n_estimators 100 100 True
min_child_weight 1 1 True
max_depth 3 3 True
max_delta_step 0 0 True
gamma 0 0 True
colsample_bynode 1 1 True
colsample_bylevel 1 1 True
booster gbtree gbtree True
scoring roc_auc roc_auc True
eval_metric error error True
In [57]:
classifier.fit(X=gb_train_clean, y=split._y_train, verbose=10)
Out[57]:
XGBClassifier(colsample_bytree=0.8, eval_metric='error', learning_rate=0.001,
              n_jobs=-1, reg_alpha=0.1, scale_pos_weight=5, scoring='roc_auc',
              seed=42, subsample=0.7)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
XGBClassifier(colsample_bytree=0.8, eval_metric='error', learning_rate=0.001,
              n_jobs=-1, reg_alpha=0.1, scale_pos_weight=5, scoring='roc_auc',
              seed=42, subsample=0.7)
In [58]:
clean_train = gb_clean_pipeline.transform(split._X_train)
clean_val = gb_clean_pipeline.transform(split._X_val)
# clean_test = gb_clean_pipeline.transform(split._X_test)

predicted_train = classifier.predict(clean_train)
predicted_val = classifier.predict(clean_val)
# predicted_test = classifier.predict(clean_test)
In [59]:
conf_matrix_stroke(split._y_train, predicted_train)
Out[59]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad26592af0>
In [60]:
conf_matrix_stroke(split._y_val, predicted_val)
Out[60]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad8dc1f1c0>

So... this is far from excellent. On the plus side, we see that there's no overfitting and that it behaves as expected (poorly) given the high imbalance we have in our data.

Improving performance by addressing the imbalanced data¶

Our model is not performing ideally because of the highly umbalanced data we have (lots of people with no-stroke, and very few with stroke!)

Oversampling using SMOTE¶

In [61]:
from sklearn.utils import resample
In [62]:
smote = SMOTE()
X_resampled, y_resampled = smote.fit_resample(clean_train, split._y_train)
In [63]:
X_resampled.shape, y_resampled.shape
Out[63]:
((5804, 14), (5804,))
In [64]:
@cached_with_pickle()
def hyperparameters_oversampling():
    return get_hyperparams_for(X_resampled, y_resampled)


cv = hyperparameters_oversampling()
Loading from cache [./cached/pickle/hyperparameters_oversampling.pickle]
In [65]:
cv.best_params_
Out[65]:
{'gamma': 0,
 'learning_rate': 0.1,
 'max_depth': 4,
 'n_estimators': 1000,
 'scale_pos_weight': 20}
In [66]:
params = cv.best_params_
classifier_resampled = xgb.XGBClassifier(
    objective="binary:logistic",
    seed=42,
    n_jobs=-1,
    scoring="roc_auc",
    eval_metric="error",
    **params
)
In [67]:
classifier_resampled.fit(X=X_resampled, y=y_resampled, verbose=10)
predicted_resampled = classifier_resampled.predict(X_resampled)
conf_matrix_stroke(y_resampled, predicted_resampled)
Out[67]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad93fc5d90>
In [68]:
predicted_resampled_val = classifier_resampled.predict(clean_val)
conf_matrix_stroke(split._y_val, predicted_resampled_val)
Out[68]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad93e28ca0>

Overfitting warning! It seems that resampling using SMOTE did not help because it resulted it overfitting on the stroke label.

Clearly oversampling did not work as it resulted in extreme overfitting.

It is possible (speculation!) that it could be caused because SMOTE uses linear interpolation, which might be "learned" too easily by XGBoost, due to its simplicity! This along with the low number of true cases, could make overfitting too tempting for XGBoost.

Subsampling¶

Let's try the opposite approach, so we can compare the performances.

In [69]:
def subsample(X, y, fieldname: str):
    clean_df = X.join(y, how="inner")
    num = clean_df[fieldname].value_counts().min()
    balanced_subsample = (
        clean_df.groupby(fieldname)
        .apply(lambda x: x.sample(num, random_state=42))
        .reset_index(level=0, drop=True)
    )
    return balanced_subsample
In [70]:
balanced_subsample = subsample(clean_train, split._y_train, "stroke")
balanced_subsample.stroke.value_counts()
Out[70]:
0    149
1    149
Name: stroke, dtype: int64
In [71]:
balanced_subsample
Out[71]:
bmi avg_glucose_level age gender_male ever_married_yes work_type_private work_type_self-employed work_type_govt_job heart_disease hypertension residence_type_urban smoking_status_formerly_smoked smoking_status_unknown smoking_status_smokes stroke
id
17175 -1.123414 -0.550415 -1.250538 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0
48129 -0.410618 -0.573196 0.564370 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0
56976 1.318574 -0.220851 -0.055355 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0
31308 0.948976 0.188117 0.254508 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0
27446 -1.743811 -0.656583 -1.560400 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
712 -0.291818 -0.485829 1.715287 0.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 1.0 0.0 0.0 1
2326 -0.080619 1.617405 1.051296 0.0 1.0 1.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 1
1210 1.397774 2.323865 1.095563 0.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1
58202 0.288979 1.358399 0.298774 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1
29552 1.490173 2.309267 0.520104 0.0 1.0 1.0 0.0 0.0 1.0 1.0 1.0 0.0 0.0 1.0 1

298 rows × 15 columns

In [72]:
X_train_balanced = balanced_subsample.drop("stroke", axis=1)
y_train_balanced = balanced_subsample.stroke

model = xgb.XGBClassifier()

model.fit(X_train_balanced, y_train_balanced)

y_pred = model.predict(clean_val)

print("Accuracy:", accuracy_score(split._y_val, y_pred))
print("Recall:", recall_score(split._y_val, y_pred))
conf_matrix_stroke(split._y_val, y_pred)
__
Accuracy: 0.7043222003929273
Recall: 0.84

This performance is already much better, out of the box, just from subsampling.

We can expand with other improvements:

Subsampling + Hyperparameter tuning¶

In [73]:
@cached_with_pickle()
def hyperparameters_xgboost():
    return get_hyperparams_for(X_train_balanced, y_train_balanced)
In [74]:
cv_subs = hyperparameters_xgboost()
Loading from cache [./cached/pickle/hyperparameters_xgboost.pickle]
In [75]:
cv_subs.best_params_
Out[75]:
{'gamma': 3,
 'learning_rate': 0.01,
 'max_depth': 2,
 'n_estimators': 1000,
 'scale_pos_weight': 1}

As expected scale_pos_weight is 1, indicating that, of all possible values, treating the dataset as balanced gives the best performance. CV sees the same thing we are seeing.

In [76]:
subsample_classifier = xgb.XGBClassifier(**cv_subs.best_params_)
subsample_classifier.fit(X_train_balanced, y_train_balanced)

y_pred = subsample_classifier.predict(clean_val)

print("Accuracy:", accuracy_score(split._y_val, y_pred))
print("Recall:", recall_score(split._y_val, y_pred))

f, ax = plt.subplots(1, 2, figsize=(10, 4))

conf_matrix_stroke(split._y_val, y_pred, normalize="true", ax=ax[0])
conf_matrix_stroke(split._y_val, y_pred, normalize=None, ax=ax[1])
ax[0].set_title("normalized across 'true' label")
ax[1].set_title("raw numbers")
plt.tight_layout()
Accuracy: 0.6463654223968566
Recall: 0.94

That's a pretty impressive boost in performance, only missing 3 patients at risk of stroke, out of 1018 patients assessed.

The small drop in the true negative rate is not of concern, as we can administer a second rounf of testing, in case of false positives, which is much better (less costly mistake = few losses of human lives) than missing a true positive.

Considering how well subsampling worked, an improvement would be to completely skip the KNN step in the data cleaning pipeline to fill in missing BMIs.

In [77]:
sns.barplot(
    y=X_train_balanced.columns,
    x=subsample_classifier.feature_importances_,
    orient="h",
    color="lightgrey",
)
Out[77]:
<AxesSubplot: >

LightGBM¶

Let's try another model to see if we can get even better performance.

The mark to beat is 0.94 accuracy.

Given than LightGBM is pretty similar to XGBoost, we will only explore this model in a lighter fashion.

We just want to compare the two models in 2 different states:

  • unbalanced + out-of-the-box performance (low maintenance mode)
  • subsampled balanced + "hyper-parameter tuned" performance
In [78]:
lgb_classifier = lgbm.LGBMClassifier()
In [79]:
lgb_classifier.get_params()
Out[79]:
{'boosting_type': 'gbdt',
 'class_weight': None,
 'colsample_bytree': 1.0,
 'importance_type': 'split',
 'learning_rate': 0.1,
 'max_depth': -1,
 'min_child_samples': 20,
 'min_child_weight': 0.001,
 'min_split_gain': 0.0,
 'n_estimators': 100,
 'n_jobs': -1,
 'num_leaves': 31,
 'objective': None,
 'random_state': None,
 'reg_alpha': 0.0,
 'reg_lambda': 0.0,
 'silent': 'warn',
 'subsample': 1.0,
 'subsample_for_bin': 200000,
 'subsample_freq': 0}
In [80]:
hyperparameter_space = {
    "num_leaves": [7, 15, 31, 63],
    "learning_rate": [0.01, 0.05, 0.1, 0.2],
    "n_estimators": [50, 100, 200, 400, 1600],
    "min_child_samples": [10, 20, 30],
    "subsample": [0.6, 0.8, 1.0],
    "colsample_bytree": [0.6, 0.8, 1.0],
}

Comparing "Out of the box" performance (XGBoost vs LightGBM)¶

We want to compare this with the XGBoost out-of-the-box performance. This means:

  • unbalanced dataset
  • default parameters
In [81]:
lgb_classifier.fit(gb_train_clean, split._y_train)
y_pred = lgb_classifier.predict(clean_val)

print("Accuracy:", accuracy_score(split._y_val, y_pred))
print("Recall:", recall_score(split._y_val, y_pred))
conf_matrix_stroke(split._y_val, y_pred)
Accuracy: 0.9459724950884086
Recall: 0.0
Out[81]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad8da1b880>

Alright, it might be super fast to train, but ith has terrible performance.

Comparing their optimal performance (XGBoost vs LightGBM)¶

This means:

  • Subsampled
  • Tuned Hyperparameters

Similarly to the previous example, we're using cv=3 just for a couple of reasons: the little amount of data we have, and the speed boost when we don't need too many folds.

In [82]:
@cached_with_pickle()
def hyperparameters_lightgbm():
    lgbm_cv = GridSearchCV(lgb_classifier, hyperparameter_space, cv=3, n_jobs=-1)
    lgbm_cv.fit(X_train_balanced, y=y_train_balanced)
    return lgbm_cv.best_params_


hyperparams = hyperparameters_lightgbm()
Loading from cache [./cached/pickle/hyperparameters_lightgbm.pickle]
In [83]:
hyperparams
Out[83]:
{'colsample_bytree': 1.0,
 'learning_rate': 0.01,
 'min_child_samples': 10,
 'n_estimators': 100,
 'num_leaves': 31,
 'subsample': 0.6}
In [84]:
lgbmC = lgbm.LGBMClassifier(**hyperparams)
lgbmC.fit(X_train_balanced, y=y_train_balanced)
y_pred = lgbmC.predict(clean_val)

print("Accuracy:", accuracy_score(split._y_val, y_pred))
print("Recall:", recall_score(split._y_val, y_pred))
conf_matrix_stroke(split._y_val, y_pred)
Accuracy: 0.6905697445972495
Recall: 0.78
Out[84]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad8d87eaf0>

Not bad. It gets similar performance to the XGBoost, but not as good (worse recall, and worse accuracy).

CatBoost¶

CatBoost is a different beast altogether as it allows us to specify categorical columns, without preprocessing. This means we don't need to encode categorical data in any way (e.g. 1-hot-encoding).

But we have an issue: our original data has some missing data in "BMI" ! For the previous cases, we took advantage of the 1 hot encoding to use KNN as an imputer.

However, right now, we have a machine learning algorithm that allows us to build models without requiring this. Let's make the most of this and not perform any 1 hot encoding. We will just drop those samples that are missing BMI (just so we can try CatBoost without 1-hot-encoded columns).

This is not a trivial decision as it might impact the predictive power. We will still try to compare models, but we should take into account that once the data is dropped, we should no longer consider these "like-with-like" comparisons!

In [85]:
def drop_bmi_nas(df: pd.DataFrame):
    df = df.drop(df[df["bmi"].isna()].index)
    return df
In [86]:
def cat_pipeline():
    input_normalize_encode = ColumnTransformer(
        transformers=[
            ("scaler", StandardScaler(), ["bmi", "avg_glucose_level", "age"]),
        ],
        remainder="passthrough",
        verbose_feature_names_out=False,
    )

    return Pipeline(
        [
            ("clean", input_normalize_encode),
            ("drop_bmi_nas", FunctionTransformer(drop_bmi_nas)),
            ("columns_to_lowercase", FunctionTransformer(columns_to_lowercase)),
            ("convert_types", FunctionTransformer(convert_types)),
        ]
    )
In [87]:
cat_clean_pipeline = cat_pipeline()
cat_train_clean = cat_clean_pipeline.fit_transform(split._X_train)
cat_train_clean.head()
Out[87]:
bmi avg_glucose_level age gender hypertension heart_disease ever_married work_type residence_type smoking_status
id
58761 0.236179 -0.408857 0.387306 Male 0 0 Yes Private Urban formerly smoked
48368 1.067775 -0.039481 0.962764 Female 0 0 Yes Self-employed Rural never smoked
8168 -0.701016 0.144765 -0.409483 Female 0 0 Yes Private Rural formerly smoked
44426 -0.239018 0.450220 -0.984942 Female 0 0 Yes Private Urban never smoked
24783 -0.793416 -0.400010 -0.675079 Female 0 0 No Private Urban formerly smoked
In [88]:
split._X_train.shape, cat_train_clean.shape
Out[88]:
((3051, 10), (2929, 10))

Remember that there were about 100 rows without stroke, and 20 with stroke (see analysis in step 3.3.1 "Check distribution of NAs in BMI")

In [89]:
categorical_colnames = [
    "gender",
    "ever_married",
    "work_type",
    "heart_disease",
    "hypertension",
    "residence_type",
    "smoking_status",
]
cat_features = [cat_train_clean.columns.get_loc(c) for c in categorical_colnames]
cat_features
Out[89]:
[3, 6, 7, 5, 4, 8, 9]

Comparing "Out of the box" performance (XGBoost vs CatBoost)¶

We want to compare this with the XGBoost out-of-the-box performance. This means:

  • unbalanced dataset
  • default parameters
In [90]:
cat = CatBoostClassifier(cat_features=cat_features, verbose=False)
cat.fit(cat_train_clean, y=split._y_train.loc[cat_train_clean.index], verbose=100)
Learning rate set to 0.016301
0:	learn: 0.6622413	total: 52.3ms	remaining: 52.2s
100:	learn: 0.1486501	total: 236ms	remaining: 2.1s
200:	learn: 0.1240673	total: 429ms	remaining: 1.71s
300:	learn: 0.1129590	total: 632ms	remaining: 1.47s
400:	learn: 0.1062494	total: 837ms	remaining: 1.25s
500:	learn: 0.0993700	total: 1.05s	remaining: 1.04s
600:	learn: 0.0906094	total: 1.28s	remaining: 847ms
700:	learn: 0.0838744	total: 1.51s	remaining: 645ms
800:	learn: 0.0774158	total: 1.74s	remaining: 432ms
900:	learn: 0.0712122	total: 1.96s	remaining: 215ms
999:	learn: 0.0658342	total: 2.18s	remaining: 0us
Out[90]:
<catboost.core.CatBoostClassifier at 0x7fad8d6a6d60>
In [91]:
cat_val_clean = cat_clean_pipeline.transform(split._X_val)
expected = split._y_val.loc[cat_val_clean.index]
y_pred = cat.predict(cat_val_clean)

print("Accuracy:", accuracy_score(expected, y_pred))
print("Recall:", recall_score(expected, y_pred))
conf_matrix_stroke(expected, y_pred)
Accuracy: 0.9560776302349336
Recall: 0.0
Out[91]:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fad8d6a6f40>

This is, once more, unsurprisingly bad.

The idea that these models are complex and powerful, but require careful spoonfeeding, is starting to settle.

Let's make things easier for them, by providing them a balanced dataset and ideal hyperparameters.

Comparing their optimal performance (XGBoost vs CatBoost)¶

This means:

  • Subsampled
  • Tuned Hyperparameters

Similarly to the previous example, we're using cv=3 just for a couple of reasons: the little amount of data we have, and the speed boost when we don't need too many folds.

In [92]:
cat_subsampled = subsample(
    cat_train_clean,
    split._y_train.loc[cat_train_clean.index],
    "stroke",
)
cat_subsampled.stroke.value_counts()
Out[92]:
0    128
1    128
Name: stroke, dtype: int64
In [93]:
cat_X = cat_subsampled.drop("stroke", axis=1)
cat_y = cat_subsampled.stroke
In [94]:
@cached_with_pickle()
def hyperparameters_catboost():
    param_grid = {
        "iterations": [5, 10, 20, 50, 100, 200],
        "learning_rate": [0.001, 0.01, 0.04],
        "depth": [6, 7, 8, 9, 10],
        "l2_leaf_reg": [1, 2, 3],
    }
    catcv = GridSearchCV(estimator=cat, param_grid=param_grid, verbose=0)
    catcv.fit(cat_X, cat_y, eval_set=[(cat_val_clean, expected)])
    return catcv


catcv = hyperparameters_catboost()
Loading from cache [./cached/pickle/hyperparameters_catboost.pickle]
In [95]:
catcv.best_params_
Out[95]:
{'depth': 10, 'iterations': 20, 'l2_leaf_reg': 1, 'learning_rate': 0.04}

Making predictions¶

In [96]:
cat = CatBoostClassifier(**catcv.best_params_, cat_features=cat_features)
cat.fit(cat_X, cat_y)
predicted = cat.predict(cat_val_clean)

print("Accuracy:", accuracy_score(expected, predicted))
print("Recall:", recall_score(expected, predicted))


f, ax = plt.subplots(1, 2, figsize=(10, 4))

conf_matrix_stroke(expected, predicted, normalize="true", ax=ax[0])
conf_matrix_stroke(expected, predicted, normalize=None, ax=ax[1])
ax[0].set_title("normalized across 'true' label")
ax[1].set_title("raw numbers")
plt.tight_layout()
0:	learn: 0.6809270	total: 7.42ms	remaining: 141ms
1:	learn: 0.6697172	total: 14.6ms	remaining: 131ms
2:	learn: 0.6573935	total: 20.9ms	remaining: 118ms
3:	learn: 0.6459018	total: 26.5ms	remaining: 106ms
4:	learn: 0.6344832	total: 33.1ms	remaining: 99.5ms
5:	learn: 0.6261632	total: 39.2ms	remaining: 91.5ms
6:	learn: 0.6175634	total: 44.3ms	remaining: 82.3ms
7:	learn: 0.6073765	total: 51.7ms	remaining: 77.6ms
8:	learn: 0.5966834	total: 58.4ms	remaining: 71.4ms
9:	learn: 0.5883944	total: 63.7ms	remaining: 63.7ms
10:	learn: 0.5790567	total: 69.3ms	remaining: 56.7ms
11:	learn: 0.5706685	total: 74.2ms	remaining: 49.5ms
12:	learn: 0.5619637	total: 80ms	remaining: 43.1ms
13:	learn: 0.5545466	total: 85ms	remaining: 36.4ms
14:	learn: 0.5470353	total: 88.9ms	remaining: 29.6ms
15:	learn: 0.5398266	total: 93.7ms	remaining: 23.4ms
16:	learn: 0.5337013	total: 94.8ms	remaining: 16.7ms
17:	learn: 0.5274122	total: 100ms	remaining: 11.2ms
18:	learn: 0.5215208	total: 106ms	remaining: 5.56ms
19:	learn: 0.5151272	total: 111ms	remaining: 0us
Accuracy: 0.6465781409601634
Recall: 0.9024390243902439

This score is very similar to the previous ones we've seen.

Before we close off this exploration and go on to select our winner, there's one last thing we wanted to do with CatBoost.

Understanding Feature importance¶

CatBoost allows us to introspect the models/estimators and understand how they are fitting our dataset.

One of the ways is to use feature_importance, which lets us understand how each feature contributes towards the final prediction.

In [97]:
def plot_feature_importance_from_cv(cv, colnames):
    feat_importance = pd.DataFrame(
        {"importance": cv.estimator.get_feature_importance()},
        index=colnames,
    )
    sns.barplot(
        feat_importance.sort_values(by="importance", ascending=False).T,
        orient="h",
        color="lightgrey",
    )
In [98]:
plot_feature_importance_from_cv(catcv, cat_train_clean.columns)

Choosing our best model¶

Normally we'd have a specific target threshold that our model has to hit, before we can consider using it in production. This threshold would be decided with business experts and stakeholder, a nd would help us make a go/no-go decision on whether a particular is worth keeping/improving.

But in this exercise, we must pick one for the rest of the tasks. So we will just compare our models against previously unseen data (the test split), and we will pick the best performing one, even if its performance is not excellent.

I hope at least one of them is not embarassingly bad, but considering the small number of datapoints we used to train them, I would not be surprised if they all performed poorly.

XGBoost on Test Data¶

In [99]:
gb_clean_test_X = gb_clean_pipeline.transform(split._X_test)
gb_clean_test_y = split._y_test.loc[gb_clean_test_X.index]

test_predicted = subsample_classifier.predict(gb_clean_test_X)
test_expected = gb_clean_test_y

print("Accuracy:", accuracy_score(test_expected, test_predicted))
print("Recall:", recall_score(test_expected, test_predicted))

f, ax = plt.subplots(1, 2, figsize=(10, 4))


conf_matrix_stroke(test_expected, test_predicted, normalize="true", ax=ax[0])
conf_matrix_stroke(test_expected, test_predicted, normalize=None, ax=ax[1])
ax[0].set_title("normalized across 'true' label")
ax[1].set_title("raw numbers")
plt.tight_layout()

xgboost_predicted = test_predicted
xgboost_expected = test_expected
Accuracy: 0.656188605108055
Recall: 0.78

LightGBM on Test Data¶

In [100]:
gb_clean_test_X = gb_clean_pipeline.transform(split._X_test)
gb_clean_test_y = split._y_test.loc[gb_clean_test_X.index]

test_predicted = lgbmC.predict(gb_clean_test_X)
test_expected = gb_clean_test_y

print("Accuracy:", accuracy_score(test_expected, test_predicted))
print("Recall:", recall_score(test_expected, test_predicted))

f, ax = plt.subplots(1, 2, figsize=(10, 4))

conf_matrix_stroke(test_expected, test_predicted, normalize="true", ax=ax[0])
conf_matrix_stroke(test_expected, test_predicted, normalize=None, ax=ax[1])
ax[0].set_title("normalized across 'true' label")
ax[1].set_title("raw numbers")
plt.tight_layout()

lgbm_predicted = test_predicted
lgbm_expected = test_expected
Accuracy: 0.6964636542239686
Recall: 0.68

CatBoost on Test Data¶

In [101]:
cat_clean_test_X = cat_clean_pipeline.transform(split._X_test)
cat_clean_test_y = split._y_test.loc[cat_clean_test_X.index]

test_predicted = cat.predict(cat_clean_test_X)
test_expected = cat_clean_test_y

print("Accuracy:", accuracy_score(test_expected, test_predicted))
print("Recall:", recall_score(test_expected, test_predicted))

f, ax = plt.subplots(1, 2, figsize=(10, 4))

conf_matrix_stroke(test_expected, test_predicted, normalize="true", ax=ax[0])
conf_matrix_stroke(test_expected, test_predicted, normalize=None, ax=ax[1])
ax[0].set_title("normalized across 'true' label")
ax[1].set_title("raw numbers")
plt.tight_layout()

cat_predicted = test_predicted
cat_expected = test_expected
Accuracy: 0.656441717791411
Recall: 0.725

Comparing performance across models¶

In order of performance (better means smaller false-negative-rate), these are the models:

  1. XGBoost, with 0.22 FNR
  2. CatBoost, with 0.28 FNR
  3. LightGBM, with 0.32 FNR
In [102]:
f, ax = plt.subplots(1, 3, figsize=(12, 4))

conf_matrix_stroke(xgboost_expected, xgboost_predicted, normalize="true", ax=ax[0])
conf_matrix_stroke(lgbm_expected, lgbm_predicted, normalize="true", ax=ax[1])
conf_matrix_stroke(cat_expected, cat_predicted, normalize="true", ax=ax[2])
ax[0].set_title("XGBoost")
ax[1].set_title("LightGBM")
ax[2].set_title("CatBoost")
plt.tight_layout()
In [103]:
xgb_predicted_proba = subsample_classifier.predict_proba(gb_clean_test_X)
lgbm_predicted_proba = lgbmC.predict_proba(gb_clean_test_X)
cat_predicted_proba = cat.predict_proba(cat_clean_test_X)
In [104]:
titles = ["XGBoost", "LightGBM", "CatBoost"]
expecteds = [xgboost_expected, lgbm_expected, cat_expected]
recalls = [xgb_predicted_proba, lgbm_predicted_proba, cat_predicted_proba]


def plot_pr(ax, row, col, i):
    charts.plot_precision_recall_over_threshold(expecteds[i], recalls[i], ax=ax)
    ax.set_title(titles[i])


charts.grid(1, 3, plot_pr, figsize=(12, 4))
__

XGBoost is not only more nuanced (broader spectrum in the thresholds), but also clearly better performing than the other two.

We will use this model for our deployed application

Exporting the pipeline/model for future reuse¶

We will persist the best model and use it to provide advice to our patients

In [105]:
joblib.dump(subsample_classifier, "streamlit/xgboost_predictive_model.gz")
joblib.dump(gb_clean_pipeline, "streamlit/xgboost_cleaning_pipeline.gz")
Out[105]:
['streamlit/xgboost_cleaning_pipeline.gz']

Hosted Application¶

These models will be used in the deployed application to provide realtime predictions of stroke risk.

You can see the code in the /streamlit folder in the root of the repository

Model Explainability¶

One of the most valuable tools we have to help our patients understand the factors that increase/decrease their risk of stroke are visualization tools.

One of these visualization tools is the SHAP library, which uses Shapley values to show the contribution of each of the patient's attributes to the predicted score.

Let's pick a random patient from one of our dataset

In [106]:
display(gb_clean_test_X.head(1))
display(split.test_df().loc[[4062]])
bmi avg_glucose_level age gender_male ever_married_yes work_type_private work_type_self-employed work_type_govt_job heart_disease hypertension residence_type_urban smoking_status_formerly_smoked smoking_status_unknown smoking_status_smokes
id
4062 0.412178 2.925705 1.272627 1.0 1.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0
gender age hypertension heart_disease ever_married work_type residence_type avg_glucose_level bmi smoking_status stroke
id
4062 Male 72.0 0 1 Yes Private Rural 238.27 NaN smokes 0
In [109]:
explainer = shap.TreeExplainer(subsample_classifier)
shap_values = explainer.shap_values(gb_clean_test_X)

target = 0
shap.initjs()
shap.force_plot(
    explainer.expected_value,
    shap_values[target, :],
    feature_names=gb_clean_test_X.columns,
)
Out[109]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

This chart shows how the model predicts their stroke risk, and how each of the factors influences.

In order of strongest contributors we have:

  • age (cannot be changed)
  • glucose level (can be changed)
  • bmi (can be changed, not significanty, and contributes very little)

Given these, we could present this chart to the patient and explain which specific actions they could take to help reduce their risk.

(This is obviously a made up scenario... we should not blindly present this to any patient, especially if no business expert has review this model's "understanding" of the data... but given the constraints of the exercise, I was looking for a semi-realistic scenario that could help understanding how to use Shapley values and force charts)

Conclusions¶

This project analyses the stroke-risk dataset (Kaggle). This dataset is fairly flat and easy to understand/use.

The data is highly unbalanced, requiring us to perform extreme subsampling in order to obtain models with a reasonable performance.

Executive Summary¶

Chosen metrics to minimize¶

Given that this data is used to predict the risk of a life-altering medical condition, we wanted to tune our models to minimize false negatives. The risk of false positives is not as impactful, as they can be mitigated by doing further tests at a later date.

Risk factors¶

The primary factors that influence risk of stroke are:

  • age
  • bmi
  • glucose level
  • smoking
  • work type

Explainability¶

We have explored various tools to understand how our model works behind the scenes. These tools include:

  • An interactive web app that patients can use to self assess
  • Individualized Force Plots that use Shapley values to visualize the contribution of each factor to a particular patient's prediction.

Technical Symmary¶

Models change a lot based on initial seed used for splitting train/val/test. Shows brittle models that change depending on the initial split.

Since the split depends on the seed, and the performance changes wildly, we can conclude that, while the models seem to work well for this specific chosen value, other values would perform worse.

  • The bad news: This is a type of overfitting
  • The good news: It can be solved by adding more data

Suggested solution: Let's close off this initial MVP, call it a success, and proceed to train better models with more data from our hospital.

Stakeholder Summary¶

We are on the brink of a world changing revolution

Also: We need more funding!